package com.enigmastation.classifier.impl;

import java.util.Collections;
import java.util.Map;
import java.util.Set;

import com.enigmastation.classifier.CategoryIncrement;
import com.enigmastation.classifier.Classifier;
import com.enigmastation.classifier.ClassifierDataModelFactory;
import com.enigmastation.classifier.ClassifierListener;
import com.enigmastation.classifier.FeatureIncrement;
import com.enigmastation.extractors.WordLister;
import com.enigmastation.extractors.impl.StemmingWordLister;
import com.google.common.collect.Sets;

/**
 * This is a simple Bayesian calculation class. It was ported from Python
 * contained in the book
 * "<a href="http://www.oreilly.com/catalog/9780596529321/index.html">Programming
 * Collective Intelligence</a>," by Toby Segaran.
 * <p/>
 * Counts are stored in the maps handed out by the configured
 * {@link ClassifierDataModelFactory}; features are extracted from training
 * items by the configured {@link WordLister}.
 * 
 * @author <a href="mailto:joeo@enigmastation.com">Joseph B. Ottinger</a>
 * @version $Revision: 36 $
 */
public class ClassifierImpl implements Classifier {

	private static final long serialVersionUID = -3349599753555172369L;

	/** Weight given to the assumed probability in {@link #getWeightedProbability}. */
	private static final double WEIGHT = 1.0;

	/** Probability assumed for a feature before any evidence is counted. */
	private static final double ASSUMED_PROBABILITY = 0.5;

	protected WordLister wordLister = null;
	private Set<ClassifierListener> trainingListeners;
	private ClassifierDataModelFactory classifierDataModelFactory;
	private boolean initialized = false;

	/**
	 * Creates a classifier and installs the default collaborators via
	 * {@link #init()}.
	 */
	public ClassifierImpl() {
		// NOTE(review): init() is overridable, so a subclass override runs
		// before that subclass's own fields are initialised. Kept as-is for
		// backward compatibility.
		init();
	}

	/**
	 * Installs default collaborators where none have been set: a
	 * {@link BasicClassifierDataModelFactory} as the data model factory and a
	 * {@link StemmingWordLister} as the word lister. Only the first call has
	 * any effect; later calls return immediately.
	 */
	public synchronized void init() {
		if (initialized) {
			return;
		}

		if (getClassifierDataModelFactory() == null) {
			setClassifierDataModelFactory(new BasicClassifierDataModelFactory());
		}

		if (wordLister == null) {
			wordLister = new StemmingWordLister();
		}

		initialized = true;
	}

	/**
	 * @return the factory that supplies the feature and category count maps
	 */
	public ClassifierDataModelFactory getClassifierDataModelFactory() {
		return classifierDataModelFactory;
	}

	/**
	 * Sets the factory that supplies the feature and category count maps.
	 * Replacing a previously configured factory is allowed.
	 * 
	 * @param classifierDataModelFactory
	 *            the factory to use from now on
	 */
	public void setClassifierDataModelFactory(
			ClassifierDataModelFactory classifierDataModelFactory) {
		this.classifierDataModelFactory = classifierDataModelFactory;
	}

	/**
	 * @return the word lister used to extract features in
	 *         {@link #train(Object, String)}
	 */
	public WordLister getWordLister() {
		return wordLister;
	}

	/**
	 * @param wordLister
	 *            the word lister used to extract features in
	 *            {@link #train(Object, String)}
	 */
	public void setWordLister(WordLister wordLister) {
		this.wordLister = wordLister;
	}

	/**
	 * Registers a listener that is notified whenever {@link #incf} or
	 * {@link #incc} changes a count. Listeners are kept in a set, so adding an
	 * equal listener twice registers it once.
	 * 
	 * @param listener
	 *            the listener to notify
	 */
	public void addListener(ClassifierListener listener) {
		if (trainingListeners == null) {
			trainingListeners = Sets.newHashSet();
		}
		trainingListeners.add(listener);
	}

	/**
	 * Increase the count of a feature/category pair.
	 * <p/>
	 * Direct port from Segaran's book, including method name
	 * 
	 * @param feature
	 *            the feature (the 'word')
	 * @param category
	 *            the category
	 * @throws IllegalStateException
	 *             if the data model factory returns no map for the feature
	 */
	public void incf(String feature, String category) {
		Map<String, Integer> fm = this.classifierDataModelFactory
				.getFeatureMap(feature);
		if (fm == null) {
			throw new IllegalStateException(
					"You must be able to create a feature map for feature '"+feature+"'");
		}

		incrementCategory(fm, category);

		// Build the event once, and only when somebody is listening.
		if (trainingListeners != null && !trainingListeners.isEmpty()) {
			FeatureIncrement fi = new FeatureIncrement(feature, category,
					fm.get(category));
			for (ClassifierListener l : trainingListeners) {
				l.handleFeatureUpdate(fi);
			}
		}
	}

	/**
	 * Increase the count of a category. Direct port from Segaran's book,
	 * including method name
	 * 
	 * @param category
	 *            the category to increment
	 */
	public void incc(String category) {
		incrementCategory(getCategoryDocCount(), category);

		// Build the event once, and only when somebody is listening.
		if (trainingListeners != null && !trainingListeners.isEmpty()) {
			CategoryIncrement ci = new CategoryIncrement(category,
					getCategoryDocCount().get(category));
			ci.setCountDelta(1);
			for (ClassifierListener l : trainingListeners) {
				l.handleCategoryUpdate(ci);
			}
		}
	}

	/**
	 * Direct port from Segaran's book, including method name
	 * 
	 * @param feature
	 *            the feature
	 * @param category
	 *            the category to query
	 * @return the number of times a feature has appeared in a category, or 0
	 *         if the feature or the pair has never been counted
	 */
	public double fcount(String feature, String category) {
		Map<String, Integer> fm = this.classifierDataModelFactory
				.getFeatureMap(feature);
		if (fm != null) {
			Integer count = fm.get(category);
			if (count != null) {
				return count;
			}
		}

		return 0.0;
	}

	/**
	 * Direct port from Segaran's book, including method name
	 * 
	 * @param category
	 *            the category to count items for
	 * @return the number of items in a category, or 0 if the category is
	 *         unknown
	 */
	public double catcount(String category) {
		Integer count = getCategoryDocCount().get(category);
		if (count == null) {
			return 0.0;
		}
		return count.doubleValue();
	}

	/**
	 * Direct port from Segaran's book, including method name
	 * 
	 * @return the total number of items
	 */
	public double totalcount() {
		return getTotalCount(getCategoryDocCount());
	}

	/**
	 * @param feature
	 *            the feature to total
	 * @return the sum of the feature's counts over all categories, or 0 if the
	 *         factory has no map for the feature
	 */
	public double totalcount(String feature) {
		return sumFeatureCounts(feature);
	}

	/**
	 * @param feature
	 *            the feature to total
	 * @return the sum of the feature's counts over all categories, or 0 if the
	 *         factory has no map for the feature
	 */
	public double getTotalFeatureCount(String feature) {
		return sumFeatureCounts(feature);
	}

	/**
	 * Direct port from Segaran's book, including method name.
	 * 
	 * @return an unmodifiable view of all categories that have a document
	 *         count
	 */
	public final Set<String> getCategories() {
		return Collections.unmodifiableSet(getCategoryDocCount().keySet());
	}

	/**
	 * Trains the classifier with one item: every unique word of the item is
	 * counted once for the category, then the category's own count is
	 * incremented.
	 * 
	 * @param item
	 *            the item to extract features from
	 * @param category
	 *            the category the item belongs to
	 */
	public void train(Object item, String category) {
		Set<String> features = wordLister.getUniqueWords(item);

		for (String f : features) {
			incf(f, category);
		}
		incc(category);
	}

	/**
	 * Convenience method for descendant classes - aids in porting from
	 * Segaran's book.
	 * <p/>
	 * The counts are doubles, so dividing by an empty category would produce
	 * NaN or Infinity rather than throw; the empty category is therefore
	 * checked for explicitly and reported as probability 0.
	 * 
	 * @param feature
	 *            the feature to consider
	 * @param category
	 *            the category
	 * @return the feature probability for the class, or 0 if the category has
	 *         no counted items
	 */
	protected double fprob(String feature, String category) {
		double categoryCount = catcount(category);
		if (categoryCount == 0.0) {
			return 0.0;
		}
		return fcount(feature, category) / categoryCount;
	}

	/**
	 * @param feature
	 *            the feature to consider
	 * @param category
	 *            the category
	 * @return the feature probability for the class
	 */
	public final double getFeatureProbability(String feature, String category) {
		return fprob(feature, category);
	}

	/**
	 * Alias for {@link #getWeightedProbability(String, String)} that keeps the
	 * method name used in Segaran's book.
	 * 
	 * @param feature
	 *            The feature to consider
	 * @param category
	 *            the category to consider weight for
	 * @return the weighted probability
	 */
	protected double weightedprob(String feature, String category) {
		return getWeightedProbability(feature, category);
	}

	/**
	 * Blends the observed feature probability with the assumed probability,
	 * weighting the observation by how often the feature has been counted
	 * across all categories.
	 * 
	 * @param feature
	 *            The feature to consider
	 * @param category
	 *            the category to consider weight for
	 * @return the weighted probability
	 */
	public double getWeightedProbability(String feature, String category) {
		double basicprob = getFeatureProbability(feature, category);
		double totals = getTotalFeatureCount(feature);
		return ((WEIGHT * ASSUMED_PROBABILITY) + (totals * basicprob))
				/ (WEIGHT + totals);
	}

	/**
	 * @return the category-to-count map supplied by the data model factory
	 */
	public final Map<String, Integer> getCategoryDocCount() {
		return classifierDataModelFactory.getCategoryCountMap();
	}

	/**
	 * Sums a feature's counts over all categories; shared by
	 * {@link #totalcount(String)} and {@link #getTotalFeatureCount(String)}.
	 */
	private double sumFeatureCounts(String feature) {
		Map<String, Integer> classifierMap = this.classifierDataModelFactory
				.getFeatureMap(feature);
		if (classifierMap == null) {
			return 0.0;
		}
		return getTotalCount(classifierMap);
	}

	/**
	 * Sums all values of a count map. Accumulates in a long so that a large
	 * model cannot silently overflow an int before the result is widened.
	 */
	private double getTotalCount(Map<String, Integer> map) {
		long totalCount = 0L;
		for (Integer i : map.values()) {
			totalCount += i;
		}
		return totalCount;
	}

	/**
	 * Adds one to the count stored under the category, starting at 1 when the
	 * category has no entry yet.
	 */
	private void incrementCategory(Map<String, Integer> map, String category) {
		Integer val = map.get(category);
		map.put(category, val == null ? 1 : val + 1);
	}
}
